{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from DaSiamRPN.net import SiamRPNBIG\n",
    "from DaSiamRPN.utils import get_axis_aligned_bbox, cxy_wh_2_rect\n",
    "from DaSiamRPN.run_SiamRPN import SiamRPN_init, SiamRPN_track\n",
    "import cv2 as cv\n",
    "import matplotlib.pyplot as plt\n",
    "import json\n",
    "import random\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "from gen_seq import gen_seq\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DaSiamRPN():\n",
    "    def __init__(self):\n",
    "        pass\n",
    "\n",
    "    def load_model(self, net_file):\n",
    "        self.net = SiamRPNBIG()\n",
    "        self.net.load_state_dict(torch.load(net_file, map_location=\"cpu\"))\n",
    "        self.net.eval().cpu()\n",
    "\n",
    "        for i in range(10):\n",
    "            self.net.temple(torch.autograd.Variable(\n",
    "                torch.FloatTensor(1, 3, 127, 127)).cpu())\n",
    "            self.net(torch.autograd.Variable(\n",
    "                torch.FloatTensor(1, 3, 255, 255)).cpu())\n",
    "\n",
    "    def init_bbox(self, image, bbox):\n",
    "        target_pos = np.array(bbox[0:2])\n",
    "        target_size = np.array(bbox[2:4])\n",
    "        self.state = SiamRPN_init(image, target_pos, target_size, self.net)\n",
    "\n",
    "    def update(self, image):\n",
    "        self.state = SiamRPN_track(self.state, image)\n",
    "        target_pos = self.state['target_pos']\n",
    "        target_size = self.state['target_sz']\n",
    "        bbox = [target_pos[0], target_pos[1], target_size[0], target_size[1]]\n",
    "        return bbox"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate sequence config\n",
    "img_list, init_bbox, gt = gen_seq(\"./dataset/OTB/Football/img\", \"./dataset/OTB/Football/groundtruth_rect.txt\")\n",
    "\n",
    "# Tracker\n",
    "tracker = DaSiamRPN()\n",
    "tracker.load_model(\"./DaSiamRPN/SiamRPNBIG.model\")\n",
    "\n",
    "# init\n",
    "image = cv.imread(img_list[0])\n",
    "image = cv.cvtColor(image, cv.COLOR_BGR2RGB)\n",
    "\n",
    "tracker.init_bbox(image, gt[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SiamRPNBIG(\n",
      "  (featureExtract): Sequential(\n",
      "    (0): Conv2d(3, 192, kernel_size=(11, 11), stride=(2, 2))\n",
      "    (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (2): ReLU(inplace)\n",
      "    (3): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "    (4): Conv2d(192, 512, kernel_size=(5, 5), stride=(1, 1))\n",
      "    (5): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (6): ReLU(inplace)\n",
      "    (7): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "    (8): Conv2d(512, 768, kernel_size=(3, 3), stride=(1, 1))\n",
      "    (9): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (10): ReLU(inplace)\n",
      "    (11): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1))\n",
      "    (12): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (13): ReLU(inplace)\n",
      "    (14): Conv2d(768, 512, kernel_size=(3, 3), stride=(1, 1))\n",
      "    (15): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "  )\n",
      "  (conv_r1): Conv2d(512, 10240, kernel_size=(3, 3), stride=(1, 1))\n",
      "  (conv_r2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1))\n",
      "  (conv_cls1): Conv2d(512, 5120, kernel_size=(3, 3), stride=(1, 1))\n",
      "  (conv_cls2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1))\n",
      "  (regress_adjust): Conv2d(20, 20, kernel_size=(1, 1), stride=(1, 1))\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "net = tracker.state[\"net\"]\n",
    "print(net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "hookconv_r1.weight\n"
     ]
    }
   ],
   "source": [
    "for item in net.named_parameters():\n",
    "    if item[0] == 'conv_r1.weight':\n",
    "        print(\"hook{}\".format(item[0]))\n",
    "        h = item[1].register_hook(lambda grad: print(grad))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[307.927241563797, 99.54678416252136, 40.173439936770535, 48.45516105214102]\n"
     ]
    }
   ],
   "source": [
    "image = cv.imread(img_list[1])\n",
    "image = cv.cvtColor(image, cv.COLOR_BGR2RGB)\n",
    "\n",
    "# update\n",
    "result_bbox = tracker.update(image)\n",
    "print(result_bbox)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def draw_rect(image, rect, color):\n",
    "    x = int(rect[0])\n",
    "    y = int(rect[1])\n",
    "    width = int(rect[2])\n",
    "    height = int(rect[3])\n",
    "\n",
    "    if color == \"green\":\n",
    "        color = (0, 255, 0)\n",
    "    elif color == \"red\":\n",
    "        color = (255, 0, 0)\n",
    "    elif color == \"blue\":\n",
    "        color = (0, 0, 255)\n",
    "\n",
    "    cv.rectangle(image, (x, y), (x+width, y+height), color)\n",
    "    return image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "ename": "ZeroDivisionError",
     "evalue": "float division by zero",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mZeroDivisionError\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-19-1649a12d8cb1>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     15\u001b[0m     \u001b[0mdpi\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m.0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 16\u001b[0;31m     \u001b[0mfigsize\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mimage\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mdpi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mimage\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mdpi\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     17\u001b[0m     \u001b[0mfig\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfigure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mframeon\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfigsize\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfigsize\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdpi\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdpi\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     18\u001b[0m     \u001b[0max\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mAxes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;36m0.\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0.\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1.\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1.\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mZeroDivisionError\u001b[0m: float division by zero"
     ]
    }
   ],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    # Generate sequence config\n",
    "    img_list, init_bbox, gt = gen_seq(seq='Football')\n",
    "\n",
    "    # Tracker\n",
    "    tracker = DaSiamRPN()\n",
    "    tracker.load_model(\"./DaSiamRPN/SiamRPNBIG.model\")\n",
    "\n",
    "    # init\n",
    "    image = cv.imread(img_list[0])\n",
    "    image = cv.cvtColor(image, cv.COLOR_BGR2RGB)\n",
    "    \n",
    "    tracker.init_bbox(image, gt[0])\n",
    "    \n",
    "    dpi = .0\n",
    "    figsize = (image.shape[0] / dpi, image.shape[1] / dpi)\n",
    "    fig = plt.figure(frameon=False, figsize=figsize, dpi=dpi)\n",
    "    ax = plt.Axes(fig, [0., 0., 1., 1.])\n",
    "    ax.set_axis_off()\n",
    "    fig.add_axes(ax)\n",
    "    im = ax.imshow(image)\n",
    "\n",
    "    # Run tracker\n",
    "    for img_file, ground_truth in zip(img_list[1:4], gt[1:4]):\n",
    "        image = cv.imread(img_file)\n",
    "        image = cv.cvtColor(image, cv.COLOR_BGR2RGB)\n",
    "\n",
    "        # draw init\n",
    "        image_draw = image.copy()\n",
    "        image_draw = draw_rect(image_draw, ground_truth, \"red\")\n",
    "\n",
    "        # update\n",
    "        result_bbox = tracker.update(image)\n",
    "        print(result_bbox)\n",
    "\n",
    "        image_draw = draw_rect(image_draw, result_bbox, \"blue\")\n",
    "\n",
    "        # image_draw = cv.cvtColor(image_draw, cv.COLOR_RGB2BGR)\n",
    "        # cv.imshow(\"image\", image_draw)\n",
    "        # cv.waitKey(0)\n",
    "        # plt.imshow(image_draw)\n",
    "        fig = plt.figure(figsize=figsize, dpi=dpi)\n",
    "        ax = plt.Axes(fig, [0., 0., 1., 1.])\n",
    "        fig.add_axes(ax)\n",
    "        ax.imshow(image_draw)\n",
    "        plt.show()\n",
    "\n",
    "    # Save result\n",
    "    # res = {}\n",
    "    # res['res'] = result_bb.round().tolist()\n",
    "    # res['type'] = 'rect'\n",
    "    # res['fps'] = fps\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
